#!/usr/bin/env bash
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
#           2022 burkliu(boji123@aliyun.com)
#           2023 gaoxinglong(gaoxinglong@xdf.cn)

# Load the WeNet recipe environment (PATH/PYTHONPATH); abort if it is missing.
. ./path.sh || exit 1;
# Prepend machine-specific cuDNN / cuBLAS libraries needed by the libtorch runtime.
export  LD_LIBRARY_PATH=/home/gaoxinglong/env/lib/cudnn_8.6.0/lib/:/home/gaoxinglong/env/bin/anaconda3/envs/wenet/lib/python3.8/site-packages/nvidia/cublas/lib/:${LD_LIBRARY_PATH}
# Use this to control how many GPUs you use. It's 1-GPU training if you specify
# just 1 GPU; otherwise it is multi-GPU training based on DDP in PyTorch.
gpus="0,1,2,3,4,5,6,7"
# gpus="0,1"
# Count the entries in the comma-separated GPU list.
n_gpus=$(echo ${gpus} | tr "," "\n" | wc -l)
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
# NOTE(review): 'ulimit -n' with no argument only PRINTS the current open-file
# limit; if the intent was to raise it, a value is missing
# (e.g. 'ulimit -n 65536') — confirm.
ulimit -n 
# Root of this example recipe; data/ and exp/ paths below hang off this.
env_root=/home/gaoxinglong/env/tools/wenet/examples/multi_cn/s0

# Run from this stage onwards: 0 = export TorchScript model,
# 1 = runtime decode, 2 = python recognize + WER scoring.
stage=0


# Experiment directory containing train.yaml and the epoch checkpoints.
local_exp=exp/train_unified_conformer_device/
# data file, such as wav.scp
# NOTE: the assignments below deliberately overwrite one another — only the
# LAST uncommented wav_scp takes effect; the earlier lines are kept as a
# quick-switch history of previously used test sets.
wav_scp=${env_root}/data/test_u/wav.scp
wav_scp=${env_root}/data/dev_u/wav.scp
# wav_scp=/home/share_ssd_data/env/device_asr_data/2023-02-23/data/wav.scp
wav_scp=/home/share_ssd_data/env/device_asr_data/2023-02-24/wav.scp
wav_scp=/home/share_ssd_data/env/device_asr_data/2023-03-01/wav.scp
wav_scp=/home/share_ssd_data/env/device_asr_data/2023-03-01/badcase_exp01/wav.scp
wav_scp=/home/share_ssd_data/env/device_asr_data/2023-02-23-dccrn/wav.scp
wav_scp=/home/share_ssd_data/env/device_asr_data/2023-02-24-dccrn/wav.scp
wav_scp=/home/share_ssd_data/env/device_asr_data/2023-03-01-dccrn/wav.scp
# wav_scp=/home/share_ssd_data/env/weimeng/workspace/keywords-test-230329/wav_annatation_done-record1_1679556067010.scp
wav_scp=/home/share_ssd_data/env/weimeng/workspace/keywords-test-230329/wav_annatation_done-record2_1679627040948.scp
wav_scp=/home/share_ssd_data/env/weimeng/workspace/keywords-test-230331/keywords-test-230331.scp
wav_scp=/home/share_ssd_data/env/weimeng/workspace/keywords-test-230411/keywords-test-230411.scp
# wav_scp=/home/gaoxinglong/env/data/device_asr_data/audio_badcase_1st/wav.scp
# data file, such as text (reference transcripts)
label_file=${env_root}/data/dev_u/text
# units file, such as units.txt (modeling-unit dictionary)
unit_file=${env_root}/data/lang_pinyin_en_all.txt
# decoder file, such as decoder.txt
model_raw_dir=${env_root}/${local_exp}
# Contextual-biasing phrase list and its boosting score for the runtime decoder.
context_path=${env_root}/${local_exp}/context_file.txt
context_scores=6.0
# Output directory — same overwrite pattern as wav_scp above: last one wins.
output_dir=${env_root}/${local_exp}/test_ctc_2023-02-24-dccrn
output_dir=${env_root}/${local_exp}/test_ctc_2023-04-11-long-distance
mkdir -p ${output_dir}

# Checkpoint to export/decode — overwrite pattern again: last one (177.pt) wins.
decode_checkpoint=${model_raw_dir}/20.pt
# decode_checkpoint=${model_raw_dir}/26.pt
decode_checkpoint=${model_raw_dir}/35.pt
decode_checkpoint=${model_raw_dir}/44.pt
decode_checkpoint=${model_raw_dir}/53.pt
decode_checkpoint=${model_raw_dir}/72.pt
decode_checkpoint=${model_raw_dir}/177.pt


scp=${wav_scp}
label_file=${label_file}
model_file=${model_file}
unit_file=${unit_file}
dir=${output_dir}

export GLOG_logtostderr=1
export GLOG_v=2

set -e

ulimit -c unlimited

nj=1
chunk_size=16
ctc_weight=0.3
reverse_weight=0.0
rescoring_weight=1.0
decoding_chunk_size=16
# only used in rescore mode for weighting different scores
# rescore_ctc_weight=0.5
# rescore_transducer_weight=0.5
# rescore_attn_weight=0.5
# only used in beam search, either pure beam search mode OR beam search inside rescoring
search_ctc_weight=0.5
# search_transducer_weight=0.7
# For CTC WFST based decoding
fst_path=
dict_path=
acoustic_scale=1.0
# beam=15.0
beam=10.0
lattice_beam=12.0
min_active=200
max_active=7000
blank_skip_thresh=1.0
length_penalty=0.0

. tools/parse_options.sh || exit 1;

echo "scp: $scp"
echo "model_file: $model_file"
echo "unit_file: $unit_file"
echo "output_dir: $dir"
echo "label_file: $label_file"


mkdir -p $dir/split${nj}

# Step 1. Split wav.scp
# split_scps=""
# for n in $(seq ${nj}); do
#   split_scps="${split_scps} ${dir}/split${nj}/wav.${n}.scp"
# done
# tools/data/split_scp.pl ${scp} ${split_scps}


if [ ${stage} -le 0 ]; then
# Stage 0: export the chosen checkpoint to TorchScript (plain + quantized)
# for the C++ libtorch runtime.
python3 wenet/bin/export_jit.py \
    --config ${model_raw_dir}/train.yaml \
    --checkpoint ${decode_checkpoint} \
    --output_file $model_raw_dir/final.zip \
    --output_quant_file $model_raw_dir/final_quant.zip

# python3 wenet/bin/export_onnx_cpu.py \
#     --config ${model_raw_dir}/train.yaml \
#     --checkpoint ${decode_checkpoint} \
#     --output_dir $model_raw_dir \
#     --chunk_size 16 --num_decoding_left_chunks 16 --reverse_weight 0 
# Decode (stage 1) with the quantized TorchScript model produced above.
# NOTE(review): model_file is only set when stage 0 runs; when starting at a
# later stage, pass --model_file explicitly or it will be empty.
model_file=${model_raw_dir}/final_quant.zip
fi
# exit 0

if [ ${stage} -le 1 ]; then
# Stage 1: decode the wav list with the libtorch runtime binary, optionally
# through a CTC-WFST graph when fst_path is supplied.
wfst_decode_opts=
# FIX: quote the operand — the original unquoted '[ ! -z $fst_path ]' collapses
# to a malformed test when fst_path is empty and breaks on values with spaces.
if [ -n "${fst_path}" ]; then
  wfst_decode_opts="--fst_path $fst_path"
  wfst_decode_opts="$wfst_decode_opts --beam $beam"
  wfst_decode_opts="$wfst_decode_opts --dict_path $dict_path"
  wfst_decode_opts="$wfst_decode_opts --lattice_beam $lattice_beam"
  wfst_decode_opts="$wfst_decode_opts --max_active $max_active"
  wfst_decode_opts="$wfst_decode_opts --min_active $min_active"
  wfst_decode_opts="$wfst_decode_opts --acoustic_scale $acoustic_scale"
  wfst_decode_opts="$wfst_decode_opts --blank_skip_thresh $blank_skip_thresh"
  wfst_decode_opts="$wfst_decode_opts --length_penalty $length_penalty"
  # Record the WFST options actually used alongside the results.
  echo $wfst_decode_opts > $dir/config
fi

bin_recog_tool=/home/gaoxinglong/env/tools/wenet/runtime/libtorch/build/bin/decoder_main
# FIX: feed ${scp} (the value echoed above and overridable via --scp through
# parse_options.sh) instead of the hard-coded ${wav_scp}. The default value is
# identical, but command-line overrides now actually take effect.
     ${bin_recog_tool}   --rescoring_weight $rescoring_weight --context_path ${context_path} --context_score ${context_scores} \
     --ctc_weight $ctc_weight \
     --reverse_weight $reverse_weight \
     --chunk_size $chunk_size \
     --wav_scp ${scp} \
     --model_path $model_file \
     --unit_path $unit_file \
     $wfst_decode_opts \
     --result ${dir}/result.text 2>&1 | tee -a ${dir}/result.log
# FIX: removed the stray 'wait' — the pipeline above runs in the foreground,
# so there are no background jobs to wait for.
fi

# NOTE: hard stop — stage 2 scoring and the RTF summary below are currently
# disabled; remove this line to re-enable them.
exit 0

if [ ${stage} -le 2 ] ; then
    # Stage 2: run the python recognizer over a raw data list with two decode
    # modes in parallel, then score each against the reference with compute-wer.
    # ps aux | grep recog | awk '{print $2}' | xargs sudo kill -9
    test_data=data/dev_u/10k.data.list
    test_data_ref=data/dev_u/text
    # FIX: the original echoed the literal string "{test_data}" — the '$' was missing.
    echo "------------------------------ test_data: ${test_data}"
    decode_modes="ctc_greedy_search ctc_prefix_beam_search"
    for mode in ${decode_modes}; do
    {
        test_dir=${model_raw_dir}/test_${mode}_chunk_${decoding_chunk_size}
        mkdir -p $test_dir
        python3 wenet/bin/recognize.py --gpu 0  --mode $mode \
            --config $model_raw_dir/train.yaml \
            --data_type raw \
            --test_data ${test_data} \
            --checkpoint $decode_checkpoint \
            --beam_size 10 \
            --batch_size 1 \
            --penalty 0.0 \
            --dict ${unit_file} \
            --ctc_weight $ctc_weight \
            --result_file $test_dir/text \
            ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
         python3 tools/compute-wer.py --char=1 --v=1 \
           ${test_data_ref} $test_dir/text > $test_dir/wer
    } &
    done
    # Barrier: block until both background decode+score jobs have finished.
    wait
fi

# Average the RTF (real-time factor) figures reported at the tail of each
# decode log into ${dir}/rtf.
# FIX: guard the division — with no matching RTF line the original 'sum/NR'
# (NR == 0) makes awk abort with a fatal division-by-zero error; also quote
# ${dir} so paths with spaces survive word-splitting.
tail "${dir}"/*.log | grep RTF | awk '{sum+=$NF} END{if (NR > 0) print sum/NR}' > "${dir}"/rtf





